{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 线性回归 — 从0开始"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "参考: http://zh.gluon.ai/chapter_supervised-learning/linear-regression-scratch.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/D/Ubuntu/package/anaconda3/lib/python3.6/site-packages/urllib3/contrib/pyopenssl.py:46: DeprecationWarning: OpenSSL.rand is deprecated - you should use os.urandom instead\n  import OpenSSL.SSL\n"
     ]
    }
   ],
   "source": [
    "# 这里噪音服从均值0和标准差为0.01的正态分布。\n",
    "\n",
    "from mxnet import ndarray as nd\n",
    "from mxnet import autograd\n",
    "\n",
    "num_inputs = 2\n",
    "num_examples = 1000\n",
    "\n",
    "true_w = [2, -3.4]\n",
    "true_b = 4.2\n",
    "\n",
    "X = nd.random_normal(shape=(num_examples, num_inputs))\n",
    "y = true_w[0] * X[:, 0] + true_w[1] * X[:, 1] + true_b\n",
    "y += .01 * nd.random_normal(shape=y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1000, 2)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = nd.dot(X, nd.array(true_w).T) + true_b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n[ 4.66994476]\n<NDArray 1 @cpu(0)>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "y += .01 * nd.random_normal(shape=y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n[ 4.6771903]\n<NDArray 1 @cpu(0)>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y[0]"
   ]
  },
  {
   "cell_type": "heading",
   "metadata": {},
   "level": 1,
   "source": [
    "## 数据读取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import random\n",
    "batch_size = 10\n",
    "\n",
    "\n",
    "def data_iter():\n",
    "    # 产生一个随机索引\n",
    "    idx = list(range(num_examples))\n",
    "    random.shuffle(idx)\n",
    "    for i in range(0, num_examples, batch_size):\n",
    "        j = nd.array(idx[i: min(i + batch_size, num_examples)])\n",
    "        yield nd.take(X, j), nd.take(y, j)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n[[-0.56938112  1.07255459]\n [-0.69402587 -0.0323206 ]\n [ 0.37170771 -0.55021369]\n [ 0.02715491 -0.26509324]\n [-1.03527749  1.69480944]\n [-1.02182448  2.19352984]\n [-0.73192883 -0.50927299]\n [ 0.32510808 -0.72485566]\n [-1.20509708 -1.76932311]\n [-0.46041864  1.80271161]]\n<NDArray 10x2 @cpu(0)> \n[-0.59124488  2.91981149  6.81410265  5.1549077  -3.63820553 -5.30337906\n  4.47576666  7.31307507  7.80073404 -2.84838247]\n<NDArray 10 @cpu(0)>\n"
     ]
    }
   ],
   "source": [
    "for data, label in data_iter():\n",
    "    print(data, label)\n",
    "    break"
   ]
  },
  {
   "cell_type": "heading",
   "metadata": {},
   "level": 2,
   "source": [
    "初始化模型参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "w = nd.random_normal(shape=(num_inputs, 1))\n",
    "b = nd.zeros((1,))\n",
    "params = [w, b]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "for param in params:\n",
    "    param.attach_grad()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 定义模型\n",
    "def net(X):\n",
    "    return nd.dot(X, w) + b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "def square_loss(yhat, y):\n",
    "    # 注意这里我们把y变形成yhat的形状来避免自动广播\n",
    "    return nd.sum((yhat - y.reshape(yhat.shape)) ** 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def SGD(params, lr):\n",
    "    for param in params:\n",
    "        param[:] = param - lr * param.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1,)\n\n[[-1.69479954]\n [ 1.87601256]]\n<NDArray 2x1 @cpu(0)>\n\n[-1.72051334]\n<NDArray 1 @cpu(0)>\nEpoch 0, average loss: 0.000124\n"
     ]
    }
   ],
   "source": [
    "epochs = 1\n",
    "learning_rate = .001\n",
    "for e in range(epochs):\n",
    "    total_loss = 0\n",
    "    for data, label in data_iter():\n",
    "        with autograd.record():\n",
    "            output = net(data)\n",
    "            loss = square_loss(output, label)\n",
    "        loss.backward()\n",
    "        print(square_loss(output, label).shape)\n",
    "        print(params[0].grad)\n",
    "        print(params[1].grad)\n",
    "        SGD(params, learning_rate)\n",
    "\n",
    "        total_loss += nd.sum(loss).asscalar()\n",
    "        break\n",
    "    print(\"Epoch %d, average loss: %f\" % (e, total_loss/num_examples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([2, -3.4], \n [[ 1.99997067]\n  [-3.39932537]]\n <NDArray 2x1 @cpu(0)>)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "true_w, w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4.2, \n [ 4.19992447]\n <NDArray 1 @cpu(0)>)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "true_b, b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = nd.array([1, 2, 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "b = nd.array((2, 3, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n[ 2.  3.  5.]\n<NDArray 3 @cpu(0)>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n[ 1.  1.  4.]\n<NDArray 3 @cpu(0)>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(a - b) ** 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n[ 1.  1.  4.]\n<NDArray 3 @cpu(0)>"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "square_loss(a, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
